NOTE: At the time this notebook was run, the grid components were running in background mode.
Components: two grid nodes (ws://localhost:3000 and ws://localhost:3001).
This notebook is based on the "Part 10: Federated Learning with Encrypted Gradient Aggregation" tutorial.
In [ ]:
import syft as sy
from syft.grid.clients.dynamic_fl_client import DynamicFLClient
import torch
import pickle
import time
import torchvision
from torchvision import datasets, transforms
import tqdm
In [ ]:
# Hook PyTorch so torch tensors gain the syft extensions needed below.
hook = sy.TorchHook(torch)

# WebSocket addresses of the grid nodes (started separately in background mode).
nodes = ["ws://localhost:3000/",
         "ws://localhost:3001/"]

# One client connection per grid node.
compute_nodes = [DynamicFLClient(hook, address) for address in nodes]
In [ ]:
# Number of MNIST training samples to pull back in a single batch.
N_SAMPLES = 10000
MNIST_PATH = './dataset'

# Standard MNIST normalization constants (dataset mean / std).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

trainset = datasets.MNIST(MNIST_PATH, download=True, train=True, transform=transform)
# batch_size=N_SAMPLES so one iteration yields the whole subset as a single tensor pair.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=N_SAMPLES, shuffle=False)

dataiter = iter(trainloader)
# BUG FIX: `dataiter.next()` is the Python 2 iterator protocol and raises
# AttributeError on Python 3 / current PyTorch; use the built-in next() instead.
images_train_mnist, labels_train_mnist = next(dataiter)
In [ ]:
# Carve the single big batch into one equal-sized chunk per connected node.
n_chunks = len(compute_nodes)
# Tuples of tensors: (dataset / number of nodes) and (labels / number of nodes).
datasets_mnist = torch.split(images_train_mnist, int(len(images_train_mnist) / n_chunks), dim=0)
labels_mnist = torch.split(labels_train_mnist, int(len(labels_train_mnist) / n_chunks), dim=0)
In [ ]:
# Tag and describe each chunk so remote workers can later find the tensors
# by search query; #X marks inputs, #Y marks labels.
tag_img = []
tag_label = []
n_nodes = len(compute_nodes)
for chunk_x, chunk_y in zip(datasets_mnist[:n_nodes], labels_mnist[:n_nodes]):
    tag_img.append(chunk_x.tag("#X", "#mnist", "#dataset").describe("The input datapoints to the MNIST dataset."))
    tag_label.append(chunk_y.tag("#Y", "#mnist", "#dataset").describe("The input labels to the MNIST dataset."))
In [ ]:
# Host one tagged chunk (data + labels) on each grid node; the returned
# values are pointers to the now-remote tensors.
node_a, node_b = compute_nodes[0], compute_nodes[1]
shared_x1 = tag_img[0].send(node_a)    # first data chunk to Bob
shared_x2 = tag_img[1].send(node_b)    # second data chunk to Alice
shared_y1 = tag_label[0].send(node_a)  # first label chunk to Bob
shared_y2 = tag_label[1].send(node_b)  # second label chunk to Alice
In [ ]:
# Show the remote-tensor pointers we kept for the hosted X and Y chunks.
for axis, first_ptr, second_ptr in (("X", shared_x1, shared_x2),
                                    ("Y", shared_y1, shared_y2)):
    print(axis + " tensor pointers: ", first_ptr, second_ptr)
In [ ]:
# Close every websocket client connection now that the data is hosted remotely.
# IDIOM FIX: iterate the clients directly instead of `range(len(...))` indexing.
for worker in compute_nodes:
    worker.close()